{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.nn.parameter import Parameter\n",
    "from torch.nn.modules.module import Module\n",
    "\n",
    "import dgl\n",
    "import dgl.function as fn\n",
    "\n",
    "import dgl.data\n",
    "import networkx as nx\n",
    "\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchdyn.nn import DataControl, DepthCat, Augmenter\n",
    "from torchdyn.datasets import *\n",
    "from torchdyn.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Quick-run switch for automated notebook validation.\n",
    "# This cell carries the `parameters` tag, presumably so that a notebook runner\n",
    "# (e.g. papermill) can override the value -- TODO confirm against the CI setup.\n",
    "dry_run = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Graph Differential Equations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Semi-supervised node classification "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook introduces `Neural GDEs` as a general high-performance model for graph-structured data. Notebook `07_graph_differential_equations` is designed from the ground up as an introduction to Neural GDEs and therefore contains ample comments to provide insights on some of our design choices. To be accessible to practitioners/researchers without prior experience with GNNs, we discuss some features of `dgl` as well, one of the PyTorch ecosystems for geometric deep learning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run on the GPU when PyTorch can see one, otherwise fall back to the CPU\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# seed for repeatability\n",
    "# cuDNN: request deterministic algorithms and switch off the benchmark auto-tuner\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "\n",
    "# fix the global PyTorch and NumPy seeds\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "# dgl offers convenient access to GNN benchmark datasets via `dgl.data`...\n",
    "# other standard datasets (e.g. Citeseer / Pubmed) are also accessible via the dgl.data\n",
    "# API. The rest of the notebook is compatible with Cora / Citeseer / Pubmed with minimal\n",
    "# modification required.\n",
    "# Indexing with [0] takes the first item of the dataset; the cells below read node\n",
    "# features, labels and split masks from its `ndata`.\n",
    "data = dgl.data.CoraGraphDataset()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(7, 140, 500, 1000)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Cora is a node-classification dataset with 2708 nodes\n",
    "X = data.ndata['feat'].to(device)\n",
    "Y = data.ndata['label'].to(device)\n",
    "\n",
    "# In transductive semi-supervised node classification tasks on graphs, the model has access to all\n",
    "# node features but only a masked subset of the labels\n",
    "train_mask = data.ndata['train_mask']\n",
    "val_mask = data.ndata['val_mask']\n",
    "test_mask = data.ndata['test_mask']\n",
    "\n",
    "num_feats = X.shape[1]\n",
    "# Derive the class count from the labels instead of hard-coding it, so the cell also works\n",
    "# for the other citation datasets (assumes labels are integer class ids 0..C-1; 7 for Cora)\n",
    "n_classes = int(Y.max().item()) + 1\n",
    "\n",
    "# 140 training samples, 500 validation, 1000 test\n",
    "n_classes, train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add exactly one self-edge for each node: existing self-loops are removed first so\n",
    "# that no node ends up with a duplicate. Both calls return a new graph and leave\n",
    "# `data` untouched, so the cell can be re-run safely.\n",
    "g = dgl.add_self_loop(dgl.remove_self_loop(data))\n",
    "# keep the graph on the same device as the node features `X`\n",
    "g = g.to(device)\n",
    "edges = g.edges()\n",
    "n_edges = g.number_of_edges()\n",
    "\n",
    "n_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute diagonal of normalization matrix D according to standard formula\n",
    "degs = g.in_degrees().float()\n",
    "norm = torch.pow(degs, -0.5)\n",
    "norm[torch.isinf(norm)] = 0\n",
    "# add to dgl.Graph in order for the norm to be accessible at training time\n",
    "g.ndata['norm'] = norm.unsqueeze(1).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neural GCDE "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like Neural ODEs, GDEs require the specification of an ODE function (`ODEFunc`), representing the set of layers that will be called repeatedly by the ODE solver.\n",
    "\n",
    "Here, we use the convolutional variant of Neural GDEs based on GCNs: `Neural GCDEs`. The only difference with alternative neural GDEs resides in the type of GNN layer utilized in the ODEFunc.\n",
    "\n",
    "For adaptive-step GDEs (dopri5) we increase the hidden dimension to 64 to reduce the stiffness of the ODE and therefore the number of ODEFunc evaluations (`NFE`: Number of Function Evaluations)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we define the auxiliary GNN model as a standard GCN. Luckily, in this example the graph is static and can thus be assigned during initialization. For varying graphs, additional bookkeeping is required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(y_hat:torch.Tensor, y:torch.Tensor):\n",
    "    \"\"\"Fraction of rows of `y_hat` whose highest-scoring column matches the label in `y`.\"\"\"\n",
    "    _, predicted = y_hat.max(dim=1)\n",
    "    hits = predicted.eq(y)\n",
    "    return hits.float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCNLayer(nn.Module):\n",
    "    \"\"\"Graph-convolution layer bound to a fixed `dgl` graph.\n",
    "\n",
    "    The forward pass applies dropout to the input, multiplies by the weight matrix,\n",
    "    scales by `g.ndata['norm']`, sums the result over incoming edges, scales by the\n",
    "    norm again, then adds the bias and applies the activation.\n",
    "\n",
    "    Args:\n",
    "        g: graph to propagate over; it must provide a `norm` node feature.\n",
    "        in_feats: size of the input node features.\n",
    "        out_feats: size of the output node features.\n",
    "        activation: callable applied to the output, or None to skip it.\n",
    "        dropout: dropout probability applied to the input; a falsy value disables dropout.\n",
    "        bias: whether to add a learnable bias vector.\n",
    "    \"\"\"\n",
    "    def __init__(self, g:dgl.DGLGraph, in_feats:int, out_feats:int, activation,\n",
    "                 dropout:float, bias:bool=True):\n",
    "        super().__init__()\n",
    "        self.g = g\n",
    "        # allocated uninitialised; the values are set by reset_parameters() below\n",
    "        self.weight = nn.Parameter(torch.empty(in_feats, out_feats))\n",
    "        self.bias = nn.Parameter(torch.empty(out_feats)) if bias else None\n",
    "        self.activation = activation\n",
    "        # None (instead of a float sentinel) marks \"no dropout\"\n",
    "        self.dropout = nn.Dropout(p=dropout) if dropout else None\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Draw weight and bias uniformly from [-1/sqrt(out_feats), 1/sqrt(out_feats)].\"\"\"\n",
    "        stdv = 1. / math.sqrt(self.weight.size(1))\n",
    "        self.weight.data.uniform_(-stdv, stdv)\n",
    "        if self.bias is not None:\n",
    "            self.bias.data.uniform_(-stdv, stdv)\n",
    "\n",
    "    def forward(self, h):\n",
    "        \"\"\"Propagate node features `h` over the stored graph and return the new features.\"\"\"\n",
    "        if self.dropout is not None:\n",
    "            h = self.dropout(h)\n",
    "        h = torch.mm(h, self.weight)\n",
    "        # normalization by square root of src degree\n",
    "        h = h * self.g.ndata['norm']\n",
    "        self.g.ndata['h'] = h\n",
    "        # `copy_u` is the current name of the deprecated `copy_src` message function\n",
    "        self.g.update_all(fn.copy_u('h', 'm'),\n",
    "                          fn.sum(msg='m', out='h'))\n",
    "        # pop so that the temporary feature does not stay on the shared graph\n",
    "        h = self.g.ndata.pop('h')\n",
    "        # normalization by square root of dst degree\n",
    "        h = h * self.g.ndata['norm']\n",
    "        # bias\n",
    "        if self.bias is not None:\n",
    "            h = h + self.bias\n",
    "        if self.activation:\n",
    "            h = self.activation(h)\n",
    "        return h"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we construct the Neural GDE as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ODE function of the GDE: two 64 -> 64 GCN layers with input dropout p=0.9, the first\n",
    "# one followed by a Softplus. This is the module handed to the ODE solver below.\n",
    "func = nn.Sequential(GCNLayer(g=g, in_feats=64, out_feats=64, activation=nn.Softplus(), dropout=0.9),\n",
    "                     GCNLayer(g=g, in_feats=64, out_feats=64, activation=None, dropout=0.9)\n",
    "                     ).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# wrap `func` in a Neural ODE solved with RK4; `s_span` is the grid [0, 0.5, 1]\n",
    "# NOTE(review): `s_span` looks like a constructor keyword of older torchdyn releases --\n",
    "# confirm that the `torchdyn.core.NeuralODE` imported at the top still accepts it\n",
    "neuralDE = NeuralODE(func, solver='rk4', s_span=torch.linspace(0, 1, 3)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# full model: GCN input layer (num_feats -> 64, dropout 0.4), the Neural GDE block,\n",
    "# and a GCN output layer (64 -> n_classes, no dropout) producing the class scores\n",
    "m = nn.Sequential(GCNLayer(g=g, in_feats=num_feats, out_feats=64, activation=None, dropout=0.4),\n",
    "                  neuralDE,\n",
    "                  GCNLayer(g=g, in_feats=64, out_feats=n_classes, activation=None, dropout=0.)\n",
    "                  ).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PerformanceContainer(object):\n",
    "    \"\"\" Simple data class for metrics logging.\n",
    "\n",
    "    `data` maps a metric name to the list of values recorded so far.\n",
    "    \"\"\"\n",
    "    def __init__(self, data:dict):\n",
    "        self.data = data\n",
    "\n",
    "    @staticmethod\n",
    "    def deep_update(x, y):\n",
    "        \"\"\"Append the values of `y` to the sequences stored under the same keys in `x`.\n",
    "\n",
    "        Args:\n",
    "            x: dict mapping keys to sequences of recorded values; modified in place.\n",
    "            y: dict mapping keys to iterables of new values.\n",
    "\n",
    "        Returns:\n",
    "            `x` itself. Keys of `y` that are missing from `x` are created.\n",
    "        \"\"\"\n",
    "        for key, values in y.items():\n",
    "            current = x.get(key)\n",
    "            if isinstance(current, list):\n",
    "                # extend in place: rebuilding the list on every call is quadratic over a run\n",
    "                current.extend(values)\n",
    "            else:\n",
    "                # missing key or non-list sequence: start / convert to a fresh list\n",
    "                x[key] = list(current or ()) + list(values)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all parameters of the full model `m`, with weight decay\n",
    "opt = torch.optim.Adam(m.parameters(), lr=1e-3, weight_decay=5e-4)\n",
    "# cross-entropy between the class scores and the integer labels\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "# metric history: one initially empty list per metric name\n",
    "logger = PerformanceContainer(data={'train_loss':[], 'train_accuracy':[],\n",
    "                                   'test_loss':[], 'test_accuracy':[],\n",
    "                                   'forward_time':[], 'backward_time':[],\n",
    "                                   })\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[150], Loss: 1.457, Train Accuracy: 0.514, Test Accuracy: 0.377\n",
      "[300], Loss: 0.730, Train Accuracy: 0.907, Test Accuracy: 0.731\n",
      "[450], Loss: 0.542, Train Accuracy: 0.921, Test Accuracy: 0.766\n",
      "[600], Loss: 0.416, Train Accuracy: 0.950, Test Accuracy: 0.816\n",
      "[750], Loss: 0.557, Train Accuracy: 0.943, Test Accuracy: 0.810\n",
      "[900], Loss: 0.353, Train Accuracy: 0.964, Test Accuracy: 0.819\n",
      "[1050], Loss: 0.265, Train Accuracy: 0.971, Test Accuracy: 0.807\n",
      "[1200], Loss: 0.340, Train Accuracy: 0.964, Test Accuracy: 0.828\n",
      "[1350], Loss: 0.201, Train Accuracy: 0.971, Test Accuracy: 0.828\n",
      "[1500], Loss: 0.368, Train Accuracy: 0.971, Test Accuracy: 0.824\n",
      "[1650], Loss: 0.255, Train Accuracy: 0.979, Test Accuracy: 0.812\n",
      "[1800], Loss: 0.241, Train Accuracy: 0.971, Test Accuracy: 0.820\n",
      "[1950], Loss: 0.304, Train Accuracy: 0.979, Test Accuracy: 0.821\n",
      "[2100], Loss: 0.248, Train Accuracy: 0.971, Test Accuracy: 0.828\n",
      "[2250], Loss: 0.223, Train Accuracy: 0.979, Test Accuracy: 0.815\n",
      "[2400], Loss: 0.180, Train Accuracy: 0.979, Test Accuracy: 0.834\n",
      "[2550], Loss: 0.321, Train Accuracy: 0.986, Test Accuracy: 0.825\n",
      "[2700], Loss: 0.166, Train Accuracy: 0.986, Test Accuracy: 0.808\n",
      "[2850], Loss: 0.171, Train Accuracy: 0.986, Test Accuracy: 0.821\n",
      "[3000], Loss: 0.190, Train Accuracy: 0.986, Test Accuracy: 0.827\n",
      "[3150], Loss: 0.207, Train Accuracy: 0.993, Test Accuracy: 0.823\n",
      "[3300], Loss: 0.159, Train Accuracy: 0.986, Test Accuracy: 0.817\n",
      "[3450], Loss: 0.183, Train Accuracy: 0.993, Test Accuracy: 0.829\n",
      "[3600], Loss: 0.161, Train Accuracy: 0.986, Test Accuracy: 0.831\n",
      "[3750], Loss: 0.143, Train Accuracy: 0.986, Test Accuracy: 0.826\n",
      "[3900], Loss: 0.182, Train Accuracy: 0.986, Test Accuracy: 0.826\n",
      "[4050], Loss: 0.156, Train Accuracy: 0.993, Test Accuracy: 0.817\n",
      "[4200], Loss: 0.177, Train Accuracy: 0.986, Test Accuracy: 0.819\n",
      "[4350], Loss: 0.160, Train Accuracy: 0.993, Test Accuracy: 0.811\n",
      "[4500], Loss: 0.182, Train Accuracy: 0.986, Test Accuracy: 0.829\n",
      "[4650], Loss: 0.130, Train Accuracy: 0.986, Test Accuracy: 0.812\n",
      "[4800], Loss: 0.143, Train Accuracy: 0.986, Test Accuracy: 0.830\n",
      "[4950], Loss: 0.207, Train Accuracy: 0.993, Test Accuracy: 0.818\n"
     ]
    }
   ],
   "source": [
    "# a dry run (automated notebook validation) only needs to prove that the loop executes\n",
    "steps = 2 if dry_run else 5000\n",
    "verbose_step = 150\n",
    "num_grad_steps = 0\n",
    "\n",
    "for i in range(steps): # looping over epochs\n",
    "    m.train()\n",
    "\n",
    "    # wall-clock timings; on CUDA kernels are launched asynchronously, so these are\n",
    "    # only indicative unless the device is synchronised first\n",
    "    start_time = time.time()\n",
    "    outputs = m(X)\n",
    "    forward_time = time.time() - start_time\n",
    "\n",
    "    y_pred = outputs\n",
    "    # only the labelled training nodes contribute to the loss\n",
    "    loss = criterion(y_pred[train_mask], Y[train_mask])\n",
    "    opt.zero_grad()\n",
    "\n",
    "    start_time = time.time()\n",
    "    loss.backward()\n",
    "    backward_time = time.time() - start_time\n",
    "\n",
    "    opt.step()\n",
    "    num_grad_steps += 1\n",
    "\n",
    "    with torch.no_grad():\n",
    "        m.eval()\n",
    "\n",
    "        # calculating outputs again with zeroed dropout\n",
    "        y_pred = m(X)\n",
    "\n",
    "        # the reported training loss is the one used for the gradient step (dropout active)\n",
    "        train_loss = loss.item()\n",
    "        train_acc = accuracy(y_pred[train_mask], Y[train_mask]).item()\n",
    "        test_acc = accuracy(y_pred[test_mask], Y[test_mask]).item()\n",
    "        test_loss = criterion(y_pred[test_mask], Y[test_mask]).item()\n",
    "        logger.deep_update(logger.data, dict(train_loss=[train_loss], train_accuracy=[train_acc],\n",
    "                           test_loss=[test_loss], test_accuracy=[test_acc],\n",
    "                           forward_time=[forward_time], backward_time=[backward_time])\n",
    "                          )\n",
    "\n",
    "    if num_grad_steps % verbose_step == 0:\n",
    "        print('[{}], Loss: {:3.3f}, Train Accuracy: {:3.3f}, Test Accuracy: {:3.3f}'.format(\n",
    "            num_grad_steps, train_loss, train_acc, test_acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchdyn",
   "language": "python",
   "name": "torchdyn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
